In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import spectral_norm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import glob
import matplotlib.image as mpimg
In [ ]:
# ----------------------------------------------------------
# Device
# ----------------------------------------------------------
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device {device}')
Using device mps
In [ ]:
# ----------------------------------------------------------
# Hyperparameters (Complete SAGAN)
# ----------------------------------------------------------
EPOCHS = 550
BATCH_SIZE = 128
IMAGE_SIZE = 32
CHANNELS_IMG = 3
LATENT_DIM = 128
EMBED_DIM = 50
GEN_LR = 1e-4
DISC_LR = 4e-4            # TTUR: discriminator learning rate is 4x the generator's, as in the SAGAN paper
BETA1, BETA2 = 0.0, 0.9   # Adam betas used with spectral normalization in SAGAN
CHECKPOINT_EVERY = 20
AUTOMOBILE_CLASS_IDX = 1
In [ ]:
# ----------------------------------------------------------
# Self-Attention Block
# ----------------------------------------------------------
class SelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.query = spectral_norm(nn.Conv2d(in_channels, in_channels // 8, 1))
self.key = spectral_norm(nn.Conv2d(in_channels, in_channels // 8, 1))
self.value = spectral_norm(nn.Conv2d(in_channels, in_channels, 1))
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
b, c, h, w = x.size()
query_out = self.query(x).view(b, -1, w*h) # (b, c//8, h*w)
key_out = self.key(x).view(b, -1, w*h) # (b, c//8, h*w)
attn = torch.bmm(query_out.permute(0, 2, 1), key_out) # (b, h*w, h*w)
attn = torch.softmax(attn, dim=-1)
value_out = self.value(x).view(b, c, w*h) # (b, c, h*w)
out = torch.bmm(value_out, attn.permute(0, 2, 1)) # (b, c, h*w)
out = out.view(b, c, h, w)
return self.gamma * out + x
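Because `gamma` is initialised to zero, the block starts out as an identity mapping and only gradually learns to mix in attention. The cell below is an illustrative sanity check, not part of the original run; it confirms the block preserves its input shape.
In [ ]:
# Illustrative check (assumed dummy sizes): output shape must match input shape.
_sa = SelfAttention(128)
_x = torch.randn(4, 128, 16, 16)
assert _sa(_x).shape == _x.shape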
In [ ]:
# ----------------------------------------------------------
# CIFAR-10 Data Loading
# ----------------------------------------------------------
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Filter by label via .targets so images are not decoded/transformed just to read labels
automobile_indices = [i for i, label in enumerate(trainset.targets) if label == AUTOMOBILE_CLASS_IDX]
automobile_dataset = Subset(trainset, automobile_indices)
# Create dataloader with only automobile images
trainloader = DataLoader(
automobile_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2
)
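CIFAR-10 contains 5,000 training images per class, so the automobile-only subset should hold 5,000 samples. The cell below is an illustrative sanity check, not part of the original run.
In [ ]:
# Illustrative sanity check: CIFAR-10 has 5,000 training images per class.
print(f"Automobile subset size: {len(automobile_dataset)}")  # expected: 5000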
In [ ]:
# ----------------------------------------------------------
# Generator (SAGAN with Spectral Norm)
# ----------------------------------------------------------
class Generator(nn.Module):
def __init__(self, latent_dim, embed_dim, num_classes=10):
super().__init__()
self.label_emb = nn.Embedding(num_classes, embed_dim)
self.init_fc = nn.Sequential(
spectral_norm(nn.Linear(latent_dim + embed_dim, 4*4*512)),
nn.BatchNorm1d(4*4*512),
nn.ReLU(True)
)
self.conv_blocks = nn.Sequential(
spectral_norm(nn.ConvTranspose2d(512, 256, 4, 2, 1)),
nn.BatchNorm2d(256),
nn.ReLU(True),
SelfAttention(256),
spectral_norm(nn.ConvTranspose2d(256, 128, 4, 2, 1)),
nn.BatchNorm2d(128),
nn.ReLU(True),
spectral_norm(nn.ConvTranspose2d(128, CHANNELS_IMG, 4, 2, 1)),
nn.Tanh()
)
def forward(self, z, labels):
emb = self.label_emb(labels)
x = torch.cat([z, emb], dim=1)
x = self.init_fc(x).view(-1, 512, 4, 4)
return self.conv_blocks(x)
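The spatial resolution doubles at each transposed convolution (4x4 -> 8x8 -> 16x16 -> 32x32), with self-attention applied at the 8x8 stage. The cell below is an illustrative shape check, not part of the original run; it assumes the hyperparameters defined above.
In [ ]:
# Illustrative shape check: latent vector + label embedding -> (b, 3, 32, 32)
_g = Generator(LATENT_DIM, EMBED_DIM)
_g.eval()  # eval mode so BatchNorm does not depend on the tiny dummy batch
_z = torch.randn(2, LATENT_DIM)
_lbl = torch.full((2,), AUTOMOBILE_CLASS_IDX, dtype=torch.long)
print(_g(_z, _lbl).shape)  # expected: torch.Size([2, 3, 32, 32])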
In [ ]:
# ----------------------------------------------------------
# Discriminator (SAGAN with Spectral Norm)
# ----------------------------------------------------------
class Discriminator(nn.Module):
def __init__(self, embed_dim, num_classes=10):
super().__init__()
self.label_emb = nn.Embedding(num_classes, embed_dim)
self.conv_blocks = nn.Sequential(
spectral_norm(nn.Conv2d(CHANNELS_IMG, 64, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
SelfAttention(128),
spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
)
self.fc = spectral_norm(nn.Linear(512*2*2 + embed_dim, 1))
def forward(self, x, labels):
bsz = x.size(0)
emb = self.label_emb(labels)
features = self.conv_blocks(x).view(bsz, -1)
combined = torch.cat([features, emb], dim=1)
return self.fc(combined)
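The discriminator downsamples 32x32 -> 16 -> 8 -> 4 -> 2 with self-attention at the 8x8 stage, then concatenates the flattened features with the label embedding before the final linear layer, producing a single unbounded logit per image (hence `BCEWithLogitsLoss` below). An illustrative shape check, not part of the original run:
In [ ]:
# Illustrative check: one logit per input image.
_d = Discriminator(EMBED_DIM)
_imgs = torch.randn(2, CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE)
_lbl = torch.full((2,), AUTOMOBILE_CLASS_IDX, dtype=torch.long)
print(_d(_imgs, _lbl).shape)  # expected: torch.Size([2, 1])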
In [ ]:
# ----------------------------------------------------------
# Initialize Model, Loss, Optimizers
# ----------------------------------------------------------
gen = Generator(LATENT_DIM, EMBED_DIM).to(device)
disc = Discriminator(EMBED_DIM).to(device)
criterion = nn.BCEWithLogitsLoss()
opt_gen = optim.Adam(gen.parameters(), lr=GEN_LR, betas=(BETA1, BETA2))
opt_disc = optim.Adam(disc.parameters(), lr=DISC_LR, betas=(BETA1, BETA2))
checkpoint_path = "adl_part_4.pt"
start_epoch = 1
In [ ]:
# ----------------------------------------------------------
# Check for Existing Checkpoint
# ----------------------------------------------------------
if os.path.exists(checkpoint_path):
ckpt = torch.load(checkpoint_path, map_location=device)
gen.load_state_dict(ckpt["gen_state_dict"])
disc.load_state_dict(ckpt["disc_state_dict"])
opt_gen.load_state_dict(ckpt["opt_gen_state_dict"])
opt_disc.load_state_dict(ckpt["opt_disc_state_dict"])
start_epoch = ckpt["epoch"] + 1
In [ ]:
# ----------------------------------------------------------
# Utility: Generate & Show 10 Samples
# ----------------------------------------------------------
def generate_and_show_samples(epoch):
gen.eval()
with torch.no_grad():
z = torch.randn(10, LATENT_DIM, device=device)
labels = torch.full((10,), AUTOMOBILE_CLASS_IDX, dtype=torch.long, device=device)
samples = gen(z, labels).cpu()
samples = (samples + 1) / 2.0
fig, axes = plt.subplots(1, 10, figsize=(22, 2.4))
for i in range(10):
img = samples[i].permute(1, 2, 0).numpy()
axes[i].imshow(img)
axes[i].axis('off')
plt.suptitle(f"Epoch {epoch}: SAGAN Samples (Automobile)", fontsize=14)
    os.makedirs('task4', exist_ok=True)  # ensure the output directory exists before saving
    plt.savefig(f'task4/automobile_gan_losses_{epoch}.png')
plt.show()
gen.train()
In [ ]:
# ----------------------------------------------------------
# Compute IS & FID
# ----------------------------------------------------------
def compute_is_fid(generator, loader, n_samples=2000):
is_metric = InceptionScore().to("cpu")
fid_metric = FrechetInceptionDistance().to("cpu")
generator.eval()
real_count = 0
for real_imgs, _ in loader:
real_imgs = real_imgs.to(device)
real_imgs_uint8 = (((real_imgs * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
fid_metric.update(real_imgs_uint8, real=True)
real_count += real_imgs.size(0)
if real_count >= n_samples:
break
fake_count = 0
while fake_count < n_samples:
z = torch.randn(BATCH_SIZE, LATENT_DIM, device=device)
labels = torch.randint(0, 10, (BATCH_SIZE,), dtype=torch.long, device=device)
with torch.no_grad():
fake_out = generator(z, labels)
fake_out_uint8 = (((fake_out * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
is_metric.update(fake_out_uint8)
fid_metric.update(fake_out_uint8, real=False)
fake_count += BATCH_SIZE
inception_score = is_metric.compute() # (mean, std)
fid_score = fid_metric.compute()
generator.train()
return inception_score[0].item(), fid_score.item()
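The torchmetrics Inception-based metrics expect `uint8` images in [0, 255], while the generator's Tanh output (and the normalised real images) live in [-1, 1]; the `(x * 0.5 + 0.5) * 255` mapping above handles the conversion. A tiny illustrative check, not part of the original run:
In [ ]:
# Illustrative check of the de-normalisation used in compute_is_fid.
_x = torch.tensor([-1.0, 0.0, 1.0])
print((((_x * 0.5) + 0.5) * 255).to(torch.uint8))  # tensor([  0, 127, 255], dtype=torch.uint8)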
In [ ]:
# ----------------------------------------------------------
# Training Loop
# ----------------------------------------------------------
for epoch in range(start_epoch, EPOCHS + 1):
for _, (real, labels) in enumerate(trainloader):
real, labels = real.to(device), labels.to(device)
bsz = real.size(0)
# --------------------
# Train Discriminator
# --------------------
disc.zero_grad()
noise = torch.randn(bsz, LATENT_DIM, device=device)
rand_labels = torch.randint(0, 10, (bsz,), dtype=torch.long, device=device)
pred_real = disc(real, labels)
loss_real = criterion(pred_real, torch.ones_like(pred_real))
fake = gen(noise, rand_labels)
pred_fake = disc(fake.detach(), rand_labels)
loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
lossD = loss_real + loss_fake
lossD.backward()
opt_disc.step()
# ----------------
# Train Generator
# ----------------
gen.zero_grad()
pred_gen = disc(fake, rand_labels)
lossG = criterion(pred_gen, torch.ones_like(pred_gen))
lossG.backward()
opt_gen.step()
print(f"[Epoch {epoch}/{EPOCHS}] LossD: {lossD.item():.4f} LossG: {lossG.item():.4f}")
if epoch % CHECKPOINT_EVERY == 0:
save_data = {
"epoch": epoch,
"gen_state_dict": gen.state_dict(),
"disc_state_dict": disc.state_dict(),
"opt_gen_state_dict": opt_gen.state_dict(),
"opt_disc_state_dict": opt_disc.state_dict()
}
torch.save(save_data, checkpoint_path)
print(f"[epoch={epoch}]Checkpoint saved: {checkpoint_path}")
generate_and_show_samples(epoch)
is_val, fid_val = compute_is_fid(gen, trainloader)
print(f"==> Epoch {epoch}: Inception Score = {is_val:.4f}, FID = {fid_val:.4f}")
print("Training complete!")
[Epoch 501/550] LossD: 0.2583 LossG: 8.0771
[Epoch 502/550] LossD: 0.0396 LossG: 11.6219
[Epoch 503/550] LossD: 0.2819 LossG: 11.0093
[Epoch 504/550] LossD: 0.0392 LossG: 8.3315
[Epoch 505/550] LossD: 0.1530 LossG: 7.2564
[Epoch 506/550] LossD: 0.2484 LossG: 8.0930
[Epoch 507/550] LossD: 0.1423 LossG: 6.9417
[Epoch 508/550] LossD: 0.1764 LossG: 7.5808
[Epoch 509/550] LossD: 0.2741 LossG: 8.0010
[Epoch 510/550] LossD: 0.0762 LossG: 8.3668
[Epoch 511/550] LossD: 0.2209 LossG: 9.2348
[Epoch 512/550] LossD: 0.0600 LossG: 7.1931
[Epoch 513/550] LossD: 0.1417 LossG: 10.0495
[Epoch 514/550] LossD: 0.4248 LossG: 7.6931
[Epoch 515/550] LossD: 0.1376 LossG: 6.9469
[Epoch 516/550] LossD: 0.7187 LossG: 15.0068
[Epoch 517/550] LossD: 0.3110 LossG: 8.2563
[Epoch 518/550] LossD: 0.0804 LossG: 6.7781
[Epoch 519/550] LossD: 0.1688 LossG: 7.4880
[Epoch 520/550] LossD: 0.2271 LossG: 8.3182
[epoch=520] Checkpoint saved: adl_part_4.pt
/Users/shivamsahil/Downloads/bits/assignments/venv/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)  # noqa: B028
==> Epoch 520: Inception Score = 1.6597, FID = 343.0216
[Epoch 521/550] LossD: 0.1547 LossG: 8.0558
[Epoch 522/550] LossD: 0.1291 LossG: 6.8157
[Epoch 523/550] LossD: 0.1425 LossG: 10.0711
[Epoch 524/550] LossD: 0.2498 LossG: 7.4953
[Epoch 525/550] LossD: 0.2674 LossG: 5.8651
[Epoch 526/550] LossD: 0.3667 LossG: 9.8780
[Epoch 527/550] LossD: 0.0309 LossG: 7.9693
[Epoch 528/550] LossD: 0.1494 LossG: 8.0814
[Epoch 529/550] LossD: 0.2831 LossG: 5.3990
[Epoch 530/550] LossD: 0.3526 LossG: 7.0946
[Epoch 531/550] LossD: 0.6136 LossG: 16.5109
[Epoch 532/550] LossD: 0.0501 LossG: 10.2105
[Epoch 533/550] LossD: 0.3019 LossG: 6.9905
[Epoch 534/550] LossD: 0.3236 LossG: 7.8011
[Epoch 535/550] LossD: 0.4857 LossG: 8.6465
[Epoch 536/550] LossD: 0.3543 LossG: 8.8010
[Epoch 537/550] LossD: 0.0894 LossG: 9.2096
[Epoch 538/550] LossD: 0.3729 LossG: 6.4945
[Epoch 539/550] LossD: 0.0231 LossG: 8.9839
[Epoch 540/550] LossD: 0.3395 LossG: 6.9768
[epoch=540] Checkpoint saved: adl_part_4.pt
==> Epoch 540: Inception Score = 1.6411, FID = 363.5405
[Epoch 541/550] LossD: 0.1517 LossG: 7.6274
[Epoch 542/550] LossD: 0.1124 LossG: 15.2578
[Epoch 543/550] LossD: 0.1488 LossG: 8.4764
[Epoch 544/550] LossD: 0.3222 LossG: 6.4993
[Epoch 545/550] LossD: 0.3849 LossG: 11.2570
[Epoch 546/550] LossD: 0.0509 LossG: 11.5791
[Epoch 547/550] LossD: 0.4406 LossG: 8.0726
[Epoch 548/550] LossD: 0.0303 LossG: 7.4664
[Epoch 549/550] LossD: 0.0429 LossG: 8.0287
[Epoch 550/550] LossD: 0.2678 LossG: 10.0007
Training complete!
In [ ]:
directory = r'task4'
# Define a custom sort key that extracts the epoch number
def extract_epoch(filename):
base = os.path.basename(filename)
try:
epoch_str = base.split('automobile_gan_losses_')[1].split('.')[0]
return int(epoch_str)
except (IndexError, ValueError):
return float('inf')
png_files = glob.glob(os.path.join(directory, '*.png'))
# Sort the list numerically by epoch number
png_files = sorted(png_files, key=extract_epoch)
# Check if any PNG files are found
if not png_files:
print("No PNG files found in the directory:", directory)
else:
n = len(png_files)
# Increase the figure size to accommodate full screen-like display
fig, axs = plt.subplots(n, 1, figsize=(22, 2.4 * n))
# If only one image, wrap axs into a list for consistency
if n == 1:
axs = [axs]
mng = plt.get_current_fig_manager()
try:
mng.window.state('zoomed')
except AttributeError:
try:
mng.window.showMaximized()
except Exception:
pass # If it fails, the figure will remain at the set figsize
# Loop through each file and display the image
for ax, file in zip(axs, png_files):
img = mpimg.imread(file)
ax.imshow(img, aspect='auto')
ax.axis('off')
ax.set_title(os.path.basename(file), fontsize=14)
plt.tight_layout()
plt.show()
In [ ]:
# Install packages for notebook conversion (TeX Live/pandoc are only needed for PDF export)
!apt-get install texlive texlive-xetex texlive-latex-extra pandoc
!pip install pypandoc
# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
# Copy the notebook to the current directory
!cp 'drive/My Drive/Colab Notebooks/Assignment2_Group75_Task4.ipynb' ./
# Convert the notebook to HTML while keeping the code and output
!jupyter nbconvert --to html "Assignment2_Group75_Task4.ipynb"
# Download the generated HTML
from google.colab import files
files.download('Assignment2_Group75_Task4.html')